﻿using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using MathNet.Numerics.Random;
using MachineLearning;

namespace HLDA
{
    /// <summary>
    /// A document in the hierarchical LDA model: its words, the topic path it is
    /// assigned to (one topic per level, root first), and per-level word statistics.
    /// </summary>
    public class HldaDoc
    {
        /// <summary>Document title/identifier, as passed to the constructor.</summary>
        public string title { get; set; }

        /// <summary>
        /// Topics on this document's path, indexed by level (root at index 0).
        /// Set by <see cref="AssignPath"/>; not set by the constructor, so it is
        /// null until a path has been assigned.
        /// </summary>
        public List<HldaTopic> path { get; set; }

        /// <summary>The document's words, in insertion order.</summary>
        public List<HldaWord> words { get; set; }

        /// <summary>
        /// Per-level counters, one per entry allocated in the constructor; read by
        /// <see cref="CalculateTopicProbability"/> as the number of words at each level.
        /// </summary>
        public List<int> topicCount { get; set; }

        /// <summary>
        /// Per-level, per-vocabulary-index counts (<c>f[level, index]</c>).
        /// Not allocated by the constructor. NOTE(review): must be allocated by the
        /// caller before words are resampled.
        /// </summary>
        public int[,] f { get; set; }

        /// <summary>Fraction of words at each level; filled by <see cref="CalculateTopicProbability"/>.</summary>
        public double[] topicProb { get; set; }

        /// <summary>
        /// Cached per-level log stick-breaking values (log v_k), precomputed so that
        /// level sampling does not have to rebuild them for every word. Only
        /// allocated here; NOTE(review): values are initialised elsewhere.
        /// </summary>
        public double[] _stickLength { get; set; }

        /// <summary>Cached per-level log(1 - v_k) companion to <see cref="_stickLength"/>.</summary>
        public double[] _stickRemaining { get; set; }

        /// <summary>Creates an empty document.</summary>
        /// <param name="title">Document title/identifier.</param>
        /// <param name="numTopic">Number of zero-initialised entries to allocate in <see cref="topicCount"/>.</param>
        public HldaDoc(string title, int numTopic)
        {
            this.title = title;
            words = new List<HldaWord>();
            topicCount = new List<int>();
            for (int l = 0; l < numTopic; l++)
            {
                topicCount.Add(0);
            }
            _stickLength = new double[Global.maxLevel];
            _stickRemaining = new double[Global.maxLevel];
        }

        /// <summary>
        /// Sets <see cref="topicProb"/> to the fraction of this document's words at
        /// each of the <c>Global.maxLevel</c> levels. An empty document yields all
        /// zeros rather than NaN (the unguarded 0/0 division).
        /// </summary>
        public void CalculateTopicProbability()
        {
            topicProb = new double[Global.maxLevel];
            int total = words.Count;
            if (total == 0)
            {
                // No words: leave every probability at 0 instead of producing NaN.
                return;
            }
            for (int k = 0; k < Global.maxLevel; k++)
            {
                topicProb[k] = (double)topicCount[k] / total;
            }
        }

        /// <summary>Appends an existing word to the document.</summary>
        /// <param name="w">The word to append.</param>
        public void AddWord(HldaWord w)
        {
            words.Add(w);
        }

        /// <summary>Appends a new, unassigned word with the given text.</summary>
        /// <param name="w">The word's text.</param>
        public void AddWord(string w)
        {
            words.Add(new HldaWord(w));
        }

        /// <summary>
        /// Detaches this document from its path. Each topic on the path loses one
        /// customer, and <c>Remove()</c> is called on any topic left with none.
        /// Each word's count is then removed from its topic. The <see cref="path"/>
        /// list and each word's level are left in place; <see cref="AssignPath"/>
        /// uses the word levels to re-place the words.
        /// Assumes a path is currently assigned and every word has a topic.
        /// </summary>
        public void UnassignPath()
        {
            foreach (HldaTopic t in path)
            {
                t.customers--;
                if (t.customers == 0)
                {
                    t.Remove();
                }
            }
            foreach (HldaWord word in words)
            {
                word.UnassignTopic();
            }
        }

        /// <summary>
        /// Assigns this document to the path passing through <paramref name="topic"/>.
        /// Levels 0..topic.level are filled by walking up the parent chain. Deeper
        /// levels, down to <c>Global.maxLevel - 1</c>, get newly created topics, each
        /// built from the topic one level above. Every topic on the path gains one
        /// customer. Each word is then placed on the path topic at its current level,
        /// and that topic's count for the word's index is incremented.
        /// Assumes every word already has a level in [0, Global.maxLevel).
        /// </summary>
        /// <param name="topic">Any topic on the desired path; it need not be a leaf.</param>
        public void AssignPath(HldaTopic topic)
        {
            HldaTopic[] ppath = new HldaTopic[Global.maxLevel];
            int level = topic.level;
            HldaTopic current = topic;
            ppath[level] = current;
            // Walk up to the root to fill the shallower levels.
            for (int l = level - 1; l >= 0; l--)
            {
                current = current.parent;
                ppath[l] = current;
            }
            // Extend below the given topic with new topics (presumably children of the previous level).
            for (int l = level + 1; l < Global.maxLevel; l++)
            {
                ppath[l] = new HldaTopic(ppath[l - 1]);
            }
            foreach (HldaTopic t in ppath)
            {
                t.customers++;
            }
            this.path = ppath.ToList();
            foreach (HldaWord w in words)
            {
                w.topic = path[w.level];
                w.topic.wordCount[w.index]++;
            }
        }
    }

    /// <summary>
    /// A single token of an <see cref="HldaDoc"/>: its text, its vocabulary index,
    /// the path level it is assigned to, and the topic at that level.
    /// </summary>
    public class HldaWord
    {
        /// <summary>Path level this word is assigned to; -1 until first assigned.</summary>
        public int level { get; set; }

        /// <summary>Topic this word currently contributes a count to, or null while unassigned.</summary>
        public HldaTopic topic { get; set; }

        /// <summary>The word's text.</summary>
        public string text { get; set; }

        /// <summary>Index used to address topic and document count arrays; -1 until set by the caller.</summary>
        public int index { get; set; }

        /// <summary>Creates an unassigned word with the given text.</summary>
        /// <param name="text">The word's text.</param>
        public HldaWord(string text)
        {
            topic = null;
            this.text = text;
            index = -1;
            level = -1;
        }

        /// <summary>
        /// Attaches this word to <paramref name="topic"/>, increments that topic's
        /// count for this word's index, and records the topic's level.
        /// </summary>
        /// <param name="topic">The topic to attach to.</param>
        public void AssignTopic(HldaTopic topic)
        {
            this.topic = topic;
            topic.wordCount[index]++;
            level = topic.level;
        }

        /// <summary>
        /// Removes this word's count from its current topic and detaches it.
        /// <see cref="level"/> is left unchanged; <c>HldaDoc.AssignPath</c> relies on
        /// it to re-place the word. Assumes the word is currently assigned.
        /// </summary>
        public virtual void UnassignTopic()
        {
            topic.wordCount[index]--;
            topic = null;
        }

        /// <summary>
        /// Resamples this word's level along <paramref name="doc"/>'s current path
        /// and moves the word to the topic at the sampled level.
        /// </summary>
        /// <remarks>
        /// The unnormalised log-score of level k is the sum of two terms:
        /// <list type="bullet">
        /// <item>the log stick-breaking length, log v_k + sum over j &lt; k of log(1 - v_j), built from the document's cached per-level values;</item>
        /// <item>the log word likelihood, log((n_{k,w} + eta) / (n_k + Veta)).</item>
        /// </list>
        /// All counts exclude this word. The doc's cached entry for the old level is
        /// recomputed without the word before sampling. Afterwards it is restored if
        /// the level is unchanged; otherwise the new level's entry is recomputed with
        /// the word included.
        /// NOTE(review): this is exact only if the cache for every level was
        /// initialised with the same denominator (pi + words.Count - 1) used here.
        /// The initialisation is not visible in this file; confirm.
        /// </remarks>
        /// <param name="doc">The document containing this word; its path, <c>f</c> and caches must be set up.</param>
        /// <returns>The topic at the newly sampled level.</returns>
        public HldaTopic SampleLevel(HldaDoc doc)
        {
            int K = doc.path.Count;
            double pi = Global.pi;
            double mpi = Global.mpi;
            double eta = Global.eta;
            double Veta = Global.Veta;

            double[] wordLikelihood = new double[K];
            double[] stickLength = new double[K];
            double[] stickRemaining = new double[K];

            // Take this word out of every count before computing its conditional.
            int oldLevel = level;
            doc.topicCount[oldLevel]--;
            doc.f[oldLevel, index]--;
            UnassignTopic();

            // Number of words in the document other than this one.
            int otherWords = doc.words.Count - 1;

            for (int k = 0; k < K; k++)
            {
                wordLikelihood[k] = Math.Log((doc.path[k].wordCount[index] + eta) / (doc.path[k].wordCount.Sum() + Veta));
            }

            // Save the old level's cached entry; it is restored if the word comes back here.
            double savedStickLength = doc._stickLength[oldLevel];
            double savedStickRemaining = doc._stickRemaining[oldLevel];

            // Only the old level's count changed, so only its cached entry needs refreshing.
            double tmp = (mpi + doc.topicCount[oldLevel]) / (pi + otherWords);
            doc._stickLength[oldLevel] = Math.Log(tmp);
            doc._stickRemaining[oldLevel] = Math.Log(1 - tmp);

            // Accumulate the stick-breaking lengths in log space.
            stickLength[0] = doc._stickLength[0];
            stickRemaining[0] = doc._stickRemaining[0];
            for (int k = 1; k < K; k++)
            {
                stickRemaining[k] = doc._stickRemaining[k] + stickRemaining[k - 1];
                stickLength[k] = doc._stickLength[k] + stickRemaining[k - 1];
            }
            for (int k = 0; k < K; k++)
            {
                stickLength[k] += wordLikelihood[k];
            }

            // Sampling.SampleLog presumably draws an index proportional to exp(score).
            double r = Global.random.NextDouble();
            level = Sampling.SampleLog(stickLength, r);

            AssignTopic(doc.path[level]);
            doc.topicCount[level]++;
            doc.f[level, index]++;

            if (level == oldLevel)
            {
                // Counts are back to their starting state, so the saved entry is still valid.
                doc._stickLength[level] = savedStickLength;
                doc._stickRemaining[level] = savedStickRemaining;
            }
            else
            {
                // The old level keeps its word-removed value. The new level now includes this word.
                tmp = (mpi + doc.topicCount[level]) / (pi + otherWords);
                doc._stickLength[level] = Math.Log(tmp);
                doc._stickRemaining[level] = Math.Log(1 - tmp);
            }
            return doc.path[level];
        }
    }
}
